

import math
import copy
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

from typing import Sequence
from einops import rearrange
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.transformer import (
    BaseTransformerLayer,
    TransformerLayerSequence,
    build_transformer_layer_sequence
)
from mmcv.cnn import (
    build_activation_layer,
    build_conv_layer,
    build_norm_layer,
    xavier_init
)
from mmcv.cnn.bricks.registry import (
    ATTENTION,
    TRANSFORMER_LAYER,
    TRANSFORMER_LAYER_SEQUENCE
)
from mmcv.utils import (
    ConfigDict,
    build_from_cfg,
    deprecated_api_warning,
    to_2tuple
)
from mmdet.models.utils.builder import TRANSFORMER


@TRANSFORMER.register_module()
class AFSAfusion(BaseModule):
    """Bidirectional cross-attention fusion of BEV and image-view features.

    ``forward`` runs the same decoder twice: once with the BEV tokens as
    queries attending to the image-view tokens, and once with the roles
    swapped, so each feature map is refined with information from the other.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of an optional
            transformer layer sequence. It is built and stored when given,
            but ``forward`` does not use it. Defaults to None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of the transformer
            layer sequence used for both fusion directions. It must expose
            an ``embed_dims`` attribute.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored on the module as ``self.cross``; it is not read
            anywhere inside this class. Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False):
        super(AFSAfusion, self).__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = build_transformer_layer_sequence(encoder)
        else:
            self.encoder = None
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        """Xavier-uniform initialise every sub-module with a >1-D weight."""
        # follow the official DETR to init parameters
        for m in self.modules():
            # Some modules expose ``weight = None`` (e.g. norm layers without
            # affine parameters); skip them instead of failing on ``.dim()``.
            weight = getattr(m, 'weight', None)
            if weight is not None and weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x, x_img, bev_pos_embed, rv_pos_embed, attn_masks=None, reg_branch=None):
        """Fuse BEV and image-view features in both directions.

        Args:
            x (Tensor): BEV feature map of shape ``[bs, c, h, w]``.
            x_img (Tensor): Image-view feature map of shape
                ``[bs * v, c, hi, wi]`` (``v`` views per sample).
            bev_pos_embed (Tensor): Positional embedding of the BEV tokens.
                It is broadcast over the batch with ``unsqueeze(1)`` /
                ``repeat(1, bs, 1)``, so presumably ``[h * w, c]`` -- confirm
                against the caller.
            rv_pos_embed (Tensor): Channel-last positional embedding of the
                image-view tokens, shape ``[bs * v, hi, wi, c]``.
            attn_masks: Accepted for interface compatibility only; it is
                ignored and ``[None]`` is always forwarded to the decoder.
            reg_branch: Passed through to the decoder unchanged.

        Returns:
            tuple[Tensor, Tensor]: ``(out_x, out_x_img)`` shaped
            ``[n, bs, c, h, w]`` and ``[n, bs * v, c, hi, wi]``, where ``n``
            is the leading dimension of the decoder output. This assumes the
            decoder returns a 4-D ``[n, num_query, bs, dim]`` tensor.
        """
        bs, _, h, w = x.shape
        _, _, hi, wi = x_img.shape

        # Flatten both feature maps into sequence-first token layouts.
        bev_memory = rearrange(x, "bs c h w -> (h w) bs c")  # [h*w, bs, c]
        rv_memory = rearrange(x_img, "(bs v) c h w -> (v h w) bs c", bs=bs)  # [v*hi*wi, bs, c]
        bev_pos_embed = bev_pos_embed.unsqueeze(1).repeat(1, bs, 1)  # add and tile the batch axis
        rv_pos_embed = rearrange(rv_pos_embed, "(bs v) h w c -> (v h w) bs c", bs=bs)

        # All-zero key padding masks (nothing is masked out). Each mask must
        # have one column per *key* token, so the two directions need masks
        # of different lengths.
        # NOTE(review): the masks inherit the features' dtype rather than
        # being bool -- confirm the decoder's attention accepts that.
        rv_mask = rv_memory.new_zeros(bs, rv_memory.shape[0])  # [bs, v*hi*wi]
        bev_mask = bev_memory.new_zeros(bs, bev_memory.shape[0])  # [bs, h*w]

        # BEV queries attend to image-view keys/values.
        # out_dec: [num_layers, num_query, bs, dim]
        out_x = self.decoder(
            query=bev_memory,
            key=rv_memory,
            value=rv_memory,
            key_pos=rv_pos_embed,
            query_pos=bev_pos_embed,
            key_padding_mask=rv_mask,
            attn_masks=[None],
            reg_branch=reg_branch,
            )
        out_x = out_x.transpose(1, 2)
        out_x = rearrange(out_x, "n bs (h w) c -> n bs c h w", h=h)

        # Image-view queries attend to BEV keys/values. The keys are the BEV
        # tokens here, so the BEV-sized mask is the one that applies.
        out_x_img = self.decoder(
            query=rv_memory,
            key=bev_memory,
            value=bev_memory,
            key_pos=bev_pos_embed,
            query_pos=rv_pos_embed,
            key_padding_mask=bev_mask,
            attn_masks=[None],
            reg_branch=reg_branch,
            )
        out_x_img = out_x_img.transpose(1, 2)
        out_x_img = rearrange(out_x_img, "n bs (v hi wi) c -> n (bs v) c hi wi", hi=hi, wi=wi)
        return out_x, out_x_img

